-
Notifications
You must be signed in to change notification settings - Fork 269
[CK_TILE] Fix alignment in Stream-K workspace buffer #3625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR fixes an alignment issue in the Stream-K workspace buffer that caused stale data to be read due to cache line conflicts between flags and partials data. The fix ensures the flags portion of the buffer is 128-byte aligned to prevent cache coherency issues.
Changes:
- Modified
get_flags_buffer_size()to pad the flags buffer to 128-byte alignment - Added three new unit tests to verify correct buffer sizing for different flag array sizes
- Re-enabled previously disabled Stream-K reduction tests
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| streamk_gemm_tile_partitioner_impl.hpp | Implements 128-byte alignment padding in get_flags_buffer_size() |
| streamk_gemm_tile_partitioner.hpp | Updates documentation for get_flags_buffer_size() to reflect alignment requirement |
| test_streamk_tile_partitioner_common.hpp | Adds three test configuration structs for testing buffer size calculations |
| test_streamk_tile_partitioner.cpp | Adds unit tests for aligned buffer sizing and updates existing test expectations |
| CMakeLists.txt | Re-enables the test_ck_tile_streamk_reduction test suite |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
In CK Tile Stream-K, the workspace buffer is used to hold flags and partials, where the first i bytes holds the flags and the remaining bytes hold partials. This change adds padding to the flags prefix of the workspace buffer to ensure the number of bytes is 128B-aligned. Without this alignment, since workgroups do not skip cache when reading from partials, they may read stale partials data in cache, leading to incorrect results. The added padding avoids the stale data reading. This change also re-enables the test_ck_tile_streamk_reduction tests.
58a8384 to
309c253
Compare
cgmillette
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
amd-anclark
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! The tests are clear, and your detailed explanation with diagrams is extremely helpful. This is exactly the level of quality and clarity we strive for. Thank you!
Proposed changes
Recently, a Stream-K reduction unit test failed; these tests were temporarily disabled in #3559 since the failure was difficult to reproduce (i.e., the test only failed ~once every 8,000-10,000 runs on one machine). After debugging, the issue was narrowed down to an alignment issue in Stream-K's workspace buffer that resulted in stale data being read by a workgroup. See the Discussion section for more details.
Hence, this PR makes the following changes:
get_flags_buffer_sizeclass method.test_ck_tile_streamk_reductionunit tests.Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
The Stream-K workspace buffer is a single buffer where the first partition stores flags for the workgroups and the second partition holds each workgroups partials (i.e., partial results for a macro tile in the output tensor) as shown in the following diagram:

But, in this scenario, there is no guarantee that flags will span an entire cache line. So, we could end up with something like this:

In the Stream-K normal and tree reductions, we use cache modifiers to skip cache in certain cases (see #3371 for details). Workgroups skip cache when reading and writing to flags and when writing to partials. But, the cache is not skipped when reading from partials. Using the example above, when a workgroup reads from flags, the entire cache line, which may contain unfinalized partials data, gets stored in cache. Since workgroups don't skip cache to read from partials, they may end up reading incorrect partials data from cache, leading to incorrect results.
While debugging, I ran various experiments to confirm the alignment issue was the cause. The strongest evidence was as follows:
While one solution is to create separate buffers for partials and flags (rather than a single workspace buffer), this option would involve an interface change. Instead, we opted to pad the flags portion of the workspace buffer to be 128B-aligned since this does not involve any interface changes. Hence, the resulting workspace buffer looks something like this:
